package com.whoyx.jiebing.utils.nlp;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import com.whoyx.jiebing.config.NlpFilePathConfig;
import com.whoyx.jiebing.utils.nlp.lac_sdk.Lac;
import com.whoyx.jiebing.utils.nlp.sentence_encoder.SentenceEncoder;
import com.whoyx.jiebing.utils.nlp.word_encoder.WordEncoder;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import me.aias.Jieba;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/**
 * NLP helper utilities: word segmentation (Baidu LAC / jieba), word-embedding
 * similarity and sentence-embedding similarity.
 *
 * <p>Heavy models are loaded per call; acceptable for low-volume use, but
 * consider caching predictors if these methods end up on a hot path.
 *
 * @author stone
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class NlpUtils {

    private final NlpFilePathConfig pathConfig;
    private final Lac lac;
    private final SentenceEncoder sentenceEncoder;

    /**
     * 分词 — segments {@code input} with the LAC model.
     *
     * @param input raw sentence to segment
     * @return the segmented words ({@code result[0]}); the POS tags
     *         ({@code result[1]}) are only logged
     * @throws RuntimeException wrapping any model-loading or inference failure
     */
    public String[] lac(String input) {
        Criteria<String, String[][]> criteria = lac.criteria();
        // Model and predictor are AutoCloseable; try-with-resources releases
        // native resources even when prediction throws.
        try (ZooModel<String, String[][]> model = criteria.loadModel();
             Predictor<String, String[][]> predictor = model.newPredictor()) {

            // result[0] = words, result[1] = POS tags
            String[][] result = predictor.predict(input);
            log.info("Words : {}", Arrays.toString(result[0]));
            log.info("Tags : {}", Arrays.toString(result[1]));
            return result[0];
        } catch (ModelNotFoundException | IOException | MalformedModelException | TranslateException e) {
            throw new RuntimeException("LAC segmentation failed for input: " + input, e);
        }
    }

    /**
     * Segments {@code input} with jieba.
     *
     * @param input raw sentence to segment
     * @return the segmented words, or {@code null} when segmentation produces
     *         no tokens
     */
    public String[] jieBa(String input) {
        Jieba parser = new Jieba();
        String[] result = parser.cut(input);
        // NOTE(review): returning null instead of an empty array forces every
        // caller to null-check; kept as-is because existing callers (see
        // wordEncoder) already rely on the null contract.
        if (result.length > 0) {
            return result;
        }
        return null;
    }

    /**
     * 词向量 — segments {@code sentence} with jieba, then for every
     * (word, keyword) pair computes cosine similarity and dot product of their
     * embeddings and appends a {@link WordSimilarity} to
     * {@code wordSimilarityList}.
     *
     * @param sentence           sentence to segment and compare
     * @param keyWord            comma-separated keyword list
     * @param wordSimilarityList output accumulator; method is a no-op when any
     *                           argument is blank/null
     * @throws RuntimeException wrapping an I/O failure while loading embeddings
     */
    public void wordEncoder(String sentence, String keyWord, List<WordSimilarity> wordSimilarityList) {
        if (StringUtils.isBlank(sentence) || StringUtils.isBlank(keyWord) || wordSimilarityList == null) {
            return;
        }
        String[] words = jieBa(sentence);
        if (null == words || words.length == 0) {
            return;
        }
        String[] keyWords = keyWord.split(",");

        Path vocabPath = Paths.get(pathConfig.getWordEncoderNpyPath());
        // NOTE(review): embedding path is hard-coded while the vocab path comes
        // from NlpFilePathConfig — consider moving it into the config as well.
        Path embeddingPath = Paths.get("lib/w2v_zhihu_dim300.npy");

        WordEncoder encoder;
        try {
            encoder = new WordEncoder(vocabPath, embeddingPath);
        } catch (IOException e) {
            throw new RuntimeException("Failed to load word encoder from " + vocabPath, e);
        }
        for (String w : words) {
            // 获取单词的特征值embedding
            float[] embedding1 = encoder.search(w);
            log.info("{}-特征值: {}", w, Arrays.toString(embedding1));
            for (String kw : keyWords) {
                float[] embedding2 = encoder.search(kw);
                // FIX: originally logged embedding1 here instead of embedding2.
                log.info("{}-特征值: {}", kw, Arrays.toString(embedding2));
                // 计算两个词向量的余弦相似度
                float cosineSim = FeatureComparison.cosineSim(embedding1, embedding2);
                log.info("{}-余弦相似度: {}", w + "-" + kw, cosineSim);
                // 计算两个词向量的内积
                float dot = FeatureComparison.dot(embedding1, embedding2);
                log.info("{}-内积: {}", w + "-" + kw, dot);

                wordSimilarityList.add(WordSimilarity.builder()
                        .word(w)
                        // FIX: record the individual keyword this similarity was
                        // computed against, not the whole comma-joined string.
                        .keyWord(kw)
                        .cosineSim(cosineSim)
                        .dot(dot)
                        .build());
            }
        }
    }

    /**
     * Computes and logs the cosine similarity between {@code sentence} and each
     * entry of {@code keySentences} using the sentence-encoder model.
     *
     * @param sentence     reference sentence; no-op when blank
     * @param keySentences sentences to compare against; no-op when null/empty
     * @throws RuntimeException wrapping any model-loading or inference failure
     */
    public void sentenceEncoder(String sentence, String[] keySentences) {
        if (StringUtils.isBlank(sentence) || null == keySentences || keySentences.length == 0) {
            return;
        }
        try (ZooModel<String, float[]> model = ModelZoo.loadModel(sentenceEncoder.criteria());
             Predictor<String, float[]> predictor = model.newPredictor()) {
            float[] embeddings1 = predictor.predict(sentence);
            for (String s : keySentences) {
                float[] embeddings2 = predictor.predict(s);
                float cosineSim = FeatureComparison.cosineSim(embeddings1, embeddings2);
                log.info("{}-余弦相似度: {}", sentence + "-" + s, cosineSim);
            }
        } catch (ModelNotFoundException | MalformedModelException | IOException | TranslateException e) {
            throw new RuntimeException("Sentence encoding failed", e);
        }
    }

    /**
     * Ad-hoc manual scratch harness; not part of the production flow.
     * (Previous commented-out calls invoked instance methods statically and
     * would not have compiled — removed.)
     */
    public static void main(String[] args) {
        String input1 = "医生我看到您最近门诊已经约满了，可我的复诊时间已经提交了，想咨询一下怎么去修改复诊时间？";
        String[] input3 = {"预约", "复查", "申请", "复查申请", "取消门诊",
                "预约通知", "复诊", "复诊检查", "手术", "改约", "修改预约",
                "使用", "使用规则", "平台使用规则", "挂号", "挂号", "挂号", "挂号"};
        // De-duplicate the sample keywords and join them into the
        // comma-separated form wordEncoder expects.
        List<String> collect = Arrays.stream(input3).distinct().collect(Collectors.toList());
        String join = StringUtils.join(collect, ",");
        System.out.println(input1 + " -> keywords: " + join);
    }

}
